###########################################################
#### figure_3.R                                        ####
#### Generates Figure 3 and accompanying descriptives  ####
###########################################################

# Variable metadata table. Every `var` flagged is_fact == 1 is treated as
# categorical; the state and county identifiers are always added to that set.
tab_in = import(here("Data", "var_data.csv"))
vars_fact = tab_in |>
  filter(is_fact == 1) |>
  pull(var) |>
  unique() |>
  c("state", "COUNTYFIP")

#### Random Forest ####
# Analysis data: the outcome, every predictor used by any model below, and the
# state/county identifiers. Columns listed in vars_fact are converted to
# factors BEFORE any rows are dropped, so factor levels come from all of
# `data`. drop_na() then performs list-wise deletion across every selected
# column (affpol included), so all predictor sets share one sample.
data_rf = data |>
  select(
    affpol,                                        # outcome
    urbanicity2, bidenvote_cnty, efi_cnty,         # geographic, not in f_st
    bidenvote, efi_state, divided, professional,   # state (f_st)
    mds1, mds2, h_diffs, index,                    # state (f_st)
    age, gender, race_cat, pid, pew_bornagain,     # individual (f_indiv)
    gdp_per_cap, unemployment,                     # state (f_st)
    general_trust, institutional_corruption,       # psychological (f_psych)
    institutional_response, vote_importance,       # psychological (f_psych)
    pride, fair_treatment, strong, newsint,        # psychological (f_psych)
    state, COUNTYFIP                               # identifiers, in no formula
  ) |>
  mutate(across(all_of(vars_fact), as.factor)) |>
  drop_na()

# Hold out 20% of the complete cases as the test set; the seed fixes which
# rows land in each partition.
set.seed(120)
data_split = data_rf |> initial_split(prop = 4 / 5)
data_train = data_split |> training()
data_test = data_split |> testing()

# Model formulas (outcome affpol). Each predictor set is assembled from named
# groups of terms so the shared groups are written only once.
# NOTE(review): term order may influence predictor column order and hence
# ranger's column sampling; keep the group order below stable when editing.
terms_cnty  = c("urbanicity2", "bidenvote_cnty", "efi_cnty")
terms_state = c("bidenvote", "efi_state", "divided", "professional",
                "mds1", "mds2", "h_diffs", "index")
terms_econ  = c("gdp_per_cap", "unemployment")
terms_indiv = c("age", "gender", "race_cat", "pid", "pew_bornagain")
terms_psych = c("pid", "general_trust", "institutional_corruption",
                "institutional_response", "vote_importance", "pride",
                "fair_treatment", "strong", "newsint")

f_indiv_st  = reformulate(c(terms_state, terms_indiv, terms_econ),
                          response = "affpol")
f_indiv_geo = reformulate(c(terms_cnty, terms_state, terms_indiv, terms_econ),
                          response = "affpol")
f_indiv     = reformulate(terms_indiv, response = "affpol")
f_st        = reformulate(c(terms_state, terms_econ), response = "affpol")
f_geo       = reformulate(c(terms_cnty, terms_state, terms_econ),
                          response = "affpol")
f_psych     = reformulate(terms_psych, response = "affpol")

# Predictor sets and their plot labels, in matching order.
f_list = list(f_indiv_st, f_indiv_geo, f_indiv, f_st, f_geo, f_psych)
set_vars = c("Individual + State", "Everything (Non-Psych)", "Individual Only",
             "State Only", "All Geographic", "Psychological")

# Cross-validation folds on the training data, shared by every predictor set
# so tuning results are comparable across sets.
set.seed(101)
folds = vfold_cv(data_train)

# Resampling runs on a doParallel backend with `n_workers` workers. Each ranger
# fit gets an equal share of the cores rather than all of them: n_workers
# concurrent fits x detectCores() threads each would oversubscribe the CPU.
# detectCores() may return NA, hence the floor of one thread.
n_workers = 4
registerDoParallel(cores = n_workers)
rf_threads = max(1L, parallel::detectCores() %/% n_workers, na.rm = TRUE)

# f_list and set_vars are parallel vectors; fail loudly if they drift apart.
stopifnot(length(f_list) == length(set_vars))

# For each predictor set: tune mtry/min_n by cross-validated RMSE, refit the
# best configuration on the training split, and score it on the test split.
# Each iteration returns the test-set metrics tagged with the set's label.
rf_output = map_dfr(seq_along(f_list), \(x) {
  print(set_vars[x])

  # Preprocessing: centre/scale numeric predictors, dummy-encode factors.
  rec_rf = recipe(f_list[[x]], data = data_train) |>
    step_normalize(all_numeric_predictors()) |>
    step_dummy(all_factor_predictors())

  # 1000-tree ranger forest with mtry and min_n left for tuning. The fixed
  # engine seed makes fits reproducible; impurity importance is computed but
  # not extracted below.
  rf_model = rand_forest(mtry = tune(), min_n = tune(), trees = 1000) |>
    set_engine("ranger",
               num.threads = rf_threads,
               importance = "impurity",
               seed = 613) |>
    set_mode("regression")

  rf_wflow = workflow() |>
    add_model(rf_model) |>
    add_recipe(rec_rf)

  # 10 automatically generated (mtry, min_n) candidates, each scored on
  # every fold.
  set.seed(345)
  tune_res = tune_grid(rf_wflow, resamples = folds, grid = 10)

  # `metric` is passed by name: supplying it positionally is deprecated in
  # current tune releases.
  best_rmse = select_best(tune_res, metric = "rmse")
  final_wf = finalize_workflow(rf_wflow, best_rmse)

  # Fit once on the training split and evaluate on the held-out test split.
  final_res = last_fit(final_wf, data_split)
  collect_metrics(final_res) |>
    mutate(set = set_vars[x])
})

export(rf_output, here("Output", "rf_output.rds"))

#### Plot ####

# Figure 3: test-set RMSE for each predictor set, plus a "Null" benchmark that
# predicts the training-set mean of affpol for every test observation.
fig_3 = rf_output |>
  add_case(.metric = "rmse", set = "Null",
           .estimate = sqrt(mean((data_test$affpol - mean(data_train$affpol))^2))) |>
  filter(.metric == "rmse") |>
  ggplot(aes(x = .estimate, y = reorder(set, .estimate))) +
  geom_col() +
  # Right-align each label just inside the end of its bar. A fixed offset in
  # RMSE units would misplace labels whenever the RMSE scale is small.
  geom_label(aes(label = round(.estimate, 1)), hjust = 1.1) +
  labs(x = "RMSE", y = "Predictor Set") +
  theme_prl()

print(fig_3)

ggsave(here("Plots", "figure_3.pdf"), fig_3,
       dpi = 600, units = "in", height = 4, width = 6)
